{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "importing Jupyter notebook from helper_fxns.ipynb\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using gpu device 2: TITAN X (Pascal) (CNMeM is disabled, cuDNN 5105)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "importing Jupyter notebook from run_dir.ipynb\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "from nbfinder import NotebookFinder\n",
    "sys.meta_path.append(NotebookFinder())\n",
    "from helper_fxns import *\n",
    "from run_dir import *\n",
    "from lasagne.nonlinearities import *\n",
    "from lasagne.init import *\n",
    "import argparse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[None, None, None, None, None, None, None, None, None]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Default hyperparameters, organised into named groups. Each group is a plain\n",
    "# dict mapping an option name to its default value.\n",
    "all_args = {\n",
    "    'data_format_args': {\n",
    "        'input_shape': (None, 16, 768, 1152),\n",
    "        '3d_time_steps_per_example': 8,\n",
    "        'im_dim': 3,\n",
    "        'variables': [\n",
    "            'PRECT', 'PS', 'PSL', 'QREFHT', 'T200', 'T500', 'TMQ', 'TREFHT',\n",
    "            'TS', 'U850', 'UBOT', 'V850', 'VBOT', 'Z1000', 'Z200', 'ZBOT',\n",
    "        ],\n",
    "        'xdim': 768,\n",
    "        'ydim': 1152,\n",
    "        'time_step_sample_frequency': 2,\n",
    "        'time_steps_per_file': 8,\n",
    "    },\n",
    "    'label_format_args': {\n",
    "        'num_classes': 4,\n",
    "        'box_sizes': [(64, 64)],\n",
    "        'scale_factor': 64,\n",
    "    },\n",
    "    'tr_val_test_args': {\n",
    "        'batch_size': 1,\n",
    "        'num_test_days': 365,\n",
    "        'num_tr_days': 365,\n",
    "        'tr_years': [1979],\n",
    "        'val_years': [1980],\n",
    "        'test_years': [1984],\n",
    "        'shuffle': False,\n",
    "        'epochs': 10000,\n",
    "        'test': False,\n",
    "    },\n",
    "    'file_args': {\n",
    "        'metadata_dir': '/home/evan/data/climate/labels/',\n",
    "        'data_dir': '/home/evan/data/climate/input',\n",
    "        'max_files_open': 1,\n",
    "    },\n",
    "    'opt_args': {\n",
    "        'learning_rate': 0.0001,\n",
    "        'weight_decay': 0.0005,\n",
    "        'lambda_ae': 10,\n",
    "        'coord_penalty': 5,\n",
    "        'size_penalty': 7,\n",
    "        'nonobj_penalty': 0.5,\n",
    "        'batch_norm': False,\n",
    "        'dropout_p': 0,\n",
    "        'yolo_batch_norm': True,\n",
    "    },\n",
    "    'arch_args': {\n",
    "        'filters_scale': 1.,\n",
    "        'filter_dim': 5,\n",
    "        'num_layers': 6,\n",
    "    },\n",
    "    'eval_args': {\n",
    "        'iou_thresh': 0.5,\n",
    "    },\n",
    "    'save_load_args': {\n",
    "        'save_weights': True,\n",
    "        'yolo_load_path': 'None',\n",
    "        'ae_load_path': 'None',\n",
    "        'save_path': './results',\n",
    "        'save_results': True,\n",
    "    },\n",
    "    'plotting_args': {\n",
    "        'num_ims_to_plot': 8,\n",
    "        'get_fmaps': False,\n",
    "        'no_plots': False,\n",
    "        'get_ims': False,\n",
    "    },\n",
    "}\n",
    "\n",
    "\n",
    "# Flatten every argument group into one option-name -> default-value mapping.\n",
    "# A plain loop is used because dict.update returns None: a comprehension would\n",
    "# only build a throwaway list of Nones and echo it as the cell output.\n",
    "default_args = {}\n",
    "for arg_group in all_args.values():\n",
    "    default_args.update(arg_group)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def process_kwargs(save_res=True):\n",
    "    \"\"\"Build the run configuration from the defaults plus command-line overrides.\n",
    "\n",
    "    Args:\n",
    "        save_res: when False, no run directory, hyperparameter dump or logger\n",
    "            is created, even if the ``save_results`` option is set.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping option name to value. When results are saved, ``save_path``\n",
    "        is replaced by the directory returned by ``create_run_dir`` and a\n",
    "        ``logger`` entry holding the result of ``setup_logging`` is added.\n",
    "    \"\"\"\n",
    "    # Work on a copy: updating default_args itself would leak this run's\n",
    "    # save_path and logger into every later call (and into the generated parser).\n",
    "    kwargs = dict(default_args)\n",
    "    kwargs.update(parse_cla())\n",
    "\n",
    "    if save_res and kwargs['save_results']:\n",
    "        # Read the configured path from kwargs; a bare `save_path` name is not\n",
    "        # defined anywhere in this notebook.\n",
    "        run_dir = create_run_dir(kwargs['save_path'])\n",
    "        kwargs['save_path'] = run_dir\n",
    "        # Dump the hyperparameters before the logger object is stored in kwargs.\n",
    "        dump_hyperparams(kwargs, run_dir)\n",
    "        kwargs['logger'] = setup_logging(kwargs['save_path'])\n",
    "\n",
    "    return kwargs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def parse_cla(defaults=None, argv=None):\n",
    "    \"\"\"Parse command-line overrides for the default hyperparameters.\n",
    "\n",
    "    One ``--<name>`` option is generated per entry of the defaults mapping\n",
    "    (the ``variables`` entry is skipped); the argument type is taken from the\n",
    "    type of the default value. Boolean options keep their configured default,\n",
    "    and those defaulting to True also get a ``--no-<name>`` switch.\n",
    "\n",
    "    Args:\n",
    "        defaults: mapping of option name -> default value. Falls back to the\n",
    "            module-level ``default_args`` when omitted.\n",
    "        argv: list of argument strings to parse instead of ``sys.argv[1:]``.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping every generated option name to its parsed value.\n",
    "    \"\"\"\n",
    "    if defaults is None:\n",
    "        defaults = default_args\n",
    "\n",
    "    # if inside a notebook, then get rid of weird notebook arguments, so that arg parsing still works\n",
    "    if argv is None and any('jupyter' in arg for arg in sys.argv):\n",
    "        sys.argv = sys.argv[:1]\n",
    "\n",
    "    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n",
    "    for name, default in defaults.items():\n",
    "        # Compare by value: `is` on strings only tests object identity.\n",
    "        if name == 'variables':\n",
    "            continue\n",
    "\n",
    "        flag = '--' + name\n",
    "        if type(default) is list:\n",
    "            # An empty list has no element to infer a type from; fall back to str.\n",
    "            elem_type = type(default[0]) if default else str\n",
    "            parser.add_argument(flag, type=elem_type, nargs='+', default=default, help=name)\n",
    "        elif type(default) is bool:\n",
    "            # Pass the default explicitly: a bare store_true defaults to False and\n",
    "            # would silently switch off every option configured as True.\n",
    "            parser.add_argument(flag, action='store_true', default=default, help=name)\n",
    "            if default:\n",
    "                # Options that are on by default need a way to be turned off.\n",
    "                parser.add_argument('--no-' + name, dest=name, action='store_false',\n",
    "                                    help='disable ' + name)\n",
    "        else:\n",
    "            parser.add_argument(flag, type=type(default), default=default, help=name)\n",
    "\n",
    "    return vars(parser.parse_args(argv))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
